package org.apache.spark.sql.execution.datasources.restjson.demo2

import java.net.URL
import java.util.function.Consumer

import com.google.gson.JsonParser
import org.apache.http.client.fluent.Request
import org.apache.http.util.EntityUtils
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext, SparkSession}

import scala.collection.mutable.ArrayBuffer

trait JsonProcessor {
  /**
    * Adapts a Scala function to a Java [[Consumer]] so Java-style iterators
    * (e.g. gson's `JsonArray.forEach`) can be driven by a Scala lambda.
    */
  implicit def toConsumer[A](function: A => Unit): Consumer[A] =
    new Consumer[A] {
      override def accept(value: A): Unit = function(value)
    }

  /**
    * Default parsing: wraps the entire HTTP response body in a single Row.
    * Subclasses override this to extract structured records from the payload.
    */
  def parse(content: String): ArrayBuffer[Row] =
    ArrayBuffer(Row.fromSeq(Seq(content)))

  /** Schema matching the default [[parse]]: one nullable string column holding the raw body. */
  def defaultSchema = StructType(Seq(StructField("body", StringType, true)))
}

private[sql] class RestJSONRelation(val path: String, val structType: StructType, paser: JsonProcessor)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  /** Schema for this relation (user-supplied or the processor's default). */
  override def schema: StructType = structType

  /**
    * Fetches the first URL via HTTP GET and converts the body into an RDD of rows
    * using the configured [[JsonProcessor]].
    *
    * Fix: the original read the entity *before* the `response != null` /
    * status-code check, so a null response threw an NPE (making the guard dead
    * code) and a non-200 body was still fetched and then discarded. The guard
    * now protects the entity access.
    */
  private def createBaseRdd(inputPaths: Array[String]): RDD[Row] = {
    val url = inputPaths.head
    val response = Request.Get(new URL(url).toURI).execute().returnResponse()
    if (response != null && response.getStatusLine.getStatusCode == 200) {
      val content = EntityUtils.toString(response.getEntity)
      sqlContext.sparkContext.makeRDD(paser.parse(content))
    } else {
      // Null response or HTTP error: surface as an empty relation rather than failing.
      sqlContext.sparkContext.makeRDD(Seq.empty[Row])
    }
  }

  /** TableScan entry point; the single `path` option is the REST endpoint URL. */
  override def buildScan(): RDD[Row] = createBaseRdd(Array(path))
}

/**
  * Schema-less provider: exposes the raw response body through the
  * [[JsonProcessor]] default single-column ("body") schema.
  */
class DefaultSource extends RelationProvider with DataSourceRegister with JsonProcessor with Logging
  with Serializable {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    // The "path" option carries the REST endpoint URL; empty string if absent.
    val endpoint = parameters.get("path").getOrElse("")
    new RestJSONRelation(endpoint, defaultSchema, this)(sqlContext)
  }

  /** Short alias usable in `spark.read.format(...)`. */
  override def shortName(): String = "restJson"
}


/**
  * Schema-aware provider: honours a user-supplied schema and falls back to the
  * processor's default when Spark hands over a null schema.
  */
class DefaultSourceNeedSchema extends SchemaRelationProvider with DataSourceRegister with JsonProcessor with Logging
  with Serializable {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String],
                              schema: StructType): BaseRelation = {
    val path = parameters.getOrElse("path", "")
    // Option(null) -> None, so a missing schema resolves to defaultSchema.
    val effectiveSchema = Option(schema).getOrElse(defaultSchema)
    new RestJSONRelation(path, effectiveSchema, this)(sqlContext)
  }

  /** Short alias usable in `spark.read.format(...)`. */
  override def shortName(): String = "restJson"
}

/**
  * Example of plugging custom business logic into the data source: extracts
  * the top-level "files" array of the JSON payload into (name, directory) rows.
  */
class MyDataSource extends DefaultSource {

  /**
    * Parses the response body as JSON and flattens the "files" array; each
    * element contributes one Row(name: String, directory: Boolean).
    *
    * Fix: the public override now declares its return type explicitly instead
    * of relying on inference.
    */
  override def parse(content: String): ArrayBuffer[Row] = {
    val rows = ArrayBuffer[Row]()
    // Extraction is coupled to this payload shape:
    // { "files": [ { "name": ..., "directory": ... }, ... ] }
    (new JsonParser()).parse(content).getAsJsonObject.get("files").getAsJsonArray.forEach(toConsumer(x => {
      val obj = x.getAsJsonObject
      rows += Row.fromSeq(Seq(obj.get("name").getAsString, obj.get("directory").getAsBoolean))
    }))
    rows
  }

  /** Two nullable columns mirroring the fields extracted by [[parse]]. */
  override def defaultSchema: StructType = StructType(Seq(StructField("name", StringType, true), StructField("directory", BooleanType, true)))
}

/**
  * Created by peibin on 2017/8/4.
  * Demo of a custom Spark SQL Data Source API implementation.
  */
object RestSourceDemo {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("RestSource")
      .master("local[*]")
      .getOrCreate()

    // Single REST endpoint reused by every read below (was duplicated 4 times).
    val endpoint = "http://10.199.212.80:8400/esperhqapp/hqapi/vfs/file"

    try {
      // Custom processor: extracts (name, directory) rows from the "files" array.
      spark.read.format("org.apache.spark.sql.execution.datasources.restjson.demo2.MyDataSource").options(Map(
        "path" -> endpoint
      )).load().show()

      // Default processor: whole response body as a single "body" column.
      spark.read.format("org.apache.spark.sql.execution.datasources.restjson.demo2.DefaultSource").options(Map(
        "path" -> endpoint
      )).load().show()

      // Package name only: Spark appends ".DefaultSource" when resolving the provider.
      spark.read.format("org.apache.spark.sql.execution.datasources.restjson.demo2").options(Map(
        "path" -> endpoint
      )).load().show()

      // Schema-aware provider with an explicit user-supplied schema.
      spark.read.format("org.apache.spark.sql.execution.datasources.restjson.demo2.DefaultSourceNeedSchema").options(Map(
        "path" -> endpoint
      )).schema(StructType(Seq(StructField("body", StringType, true)))).load().show()
    } finally {
      // Fix: the original never stopped the SparkSession, leaking the local context.
      spark.stop()
    }
  }
}
